// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using ILLink.Shared.TrimAnalysis;
using ILLink.Shared.TypeSystemProxy;
using Mono.Cecil;
using Mono.Cecil.Cil;
using Mono.Linker.Steps;
using MultiValue = ILLink.Shared.DataFlow.ValueSet<ILLink.Shared.DataFlow.SingleValue>;

namespace Mono.Linker.Dataflow
{
    /// <summary>
    /// Method body scanner that records trim-analysis patterns (annotated value assignments,
    /// method calls and generic instantiation accesses) into <see cref="TrimAnalysisPatterns"/>.
    /// The recorded patterns are marked and turned into diagnostics by
    /// <see cref="InterproceduralScan(MethodIL)"/> after the base scan finishes.
    /// </summary>
    sealed class ReflectionMethodBodyScanner : MethodBodyScanner
    {
        readonly MarkStep _markStep;

        // Origin attached to recorded patterns. It is reset to the scanned method in Scan and
        // advanced to the current instruction offset by the Handle* overrides before recording.
        MessageOrigin _origin;

        readonly FlowAnnotations _annotations;

        // Constructed with enabled: false and passed to the static HandleCall while scanning.
        // InterproceduralScan builds a separate enabled marker to process the recorded patterns.
        readonly ReflectionMarker _reflectionMarker;

        public readonly TrimAnalysisPatternStore TrimAnalysisPatterns;

        /// <summary>
        /// Returns whether a call to <paramref name="calledMethod"/> requires this scanner.
        /// Unresolvable methods never do.
        /// </summary>
        public static bool RequiresReflectionMethodBodyScannerForCallSite(LinkContext context, MethodReference calledMethod)
        {
            MethodDefinition? methodDefinition = context.TryResolve(calledMethod);
            if (methodDefinition == null)
                return false;

            var annotations = context.Annotations;
            var flowAnnotations = annotations.FlowAnnotations;
            // Any one of these is enough: an intrinsic ordered after the sentinel, data flow annotations
            // on the callee, annotated generic arguments, RequiresUnreferencedCode, or a COM-dangerous P/Invoke.
            return Intrinsics.GetIntrinsicIdForMethod(methodDefinition) > IntrinsicId.RequiresReflectionBodyScanner_Sentinel ||
                flowAnnotations.RequiresDataFlowAnalysis(methodDefinition) ||
                GenericArgumentDataFlow.RequiresGenericArgumentDataFlow(flowAnnotations, calledMethod) ||
                annotations.DoesMethodRequireUnreferencedCode(methodDefinition, out _) ||
                IsPInvokeDangerous(methodDefinition, context, out _);
        }

        /// <summary>
        /// Returns whether the body of <paramref name="methodDefinition"/> itself must be scanned,
        /// i.e. it is an intrinsic ordered after the sentinel or it requires data flow analysis.
        /// </summary>
        public static bool RequiresReflectionMethodBodyScannerForMethodBody(LinkContext context, MethodDefinition methodDefinition)
        {
            return Intrinsics.GetIntrinsicIdForMethod(methodDefinition) > IntrinsicId.RequiresReflectionBodyScanner_Sentinel ||
                context.Annotations.FlowAnnotations.RequiresDataFlowAnalysis(methodDefinition);
        }

        /// <summary>
        /// Returns whether an access to <paramref name="field"/> requires this scanner.
        /// Unresolvable fields never do.
        /// </summary>
        public static bool RequiresReflectionMethodBodyScannerForAccess(LinkContext context, FieldReference field)
        {
            FieldDefinition? fieldDefinition = context.TryResolve(field);
            if (fieldDefinition == null)
                return false;

            var flowAnnotations = context.Annotations.FlowAnnotations;
            return GenericArgumentDataFlow.RequiresGenericArgumentDataFlow(flowAnnotations, field) ||
                flowAnnotations.RequiresDataFlowAnalysis(fieldDefinition);
        }

        /// <summary>
        /// Returns whether an access to <paramref name="type"/> requires this scanner: either its
        /// generic arguments need data flow validation or the type has RequiresUnreferencedCode.
        /// Unresolvable types never do.
        /// </summary>
        public static bool RequiresReflectionMethodBodyScannerForAccess(LinkContext context, TypeReference type)
        {
            TypeDefinition? typeDefinition = context.TryResolve(type);
            if (typeDefinition == null)
                return false;

            var annotations = context.Annotations;
            return GenericArgumentDataFlow.RequiresGenericArgumentDataFlow(annotations.FlowAnnotations, type)
                || annotations.TryGetLinkerAttribute<RequiresUnreferencedCodeAttribute>(typeDefinition, out _);
        }

        /// <summary>
        /// Creates a scanner that reports through <paramref name="parent"/> and starts at <paramref name="origin"/>.
        /// </summary>
        public ReflectionMethodBodyScanner(LinkContext context, MarkStep parent, MessageOrigin origin)
            : base(context)
        {
            _markStep = parent;
            _origin = origin;
            _annotations = context.Annotations.FlowAnnotations;
            _reflectionMarker = new ReflectionMarker(context, parent, enabled: false);
            TrimAnalysisPatterns = new TrimAnalysisPatternStore(MultiValueLattice, context);
        }

        /// <summary>
        /// Runs the base interprocedural scan (which records patterns), then processes every recorded
        /// pattern with an enabled <see cref="ReflectionMarker"/> to mark and produce diagnostics.
        /// </summary>
        public override void InterproceduralScan(MethodIL methodIL)
        {
            base.InterproceduralScan(methodIL);

            var reflectionMarker = new ReflectionMarker(_context, _markStep, enabled: true);
            TrimAnalysisPatterns.MarkAndProduceDiagnostics(reflectionMarker, _markStep);
        }

        /// <summary>
        /// Resets the origin to the method being scanned before delegating to the base scan.
        /// </summary>
        protected override void Scan(MethodIL methodIL, ref InterproceduralState interproceduralState)
        {
            _origin = new MessageOrigin(methodIL.Method);
            base.Scan(methodIL, ref interproceduralState);
        }

        protected override void WarnAboutInvalidILInMethod(MethodIL methodIl, int ilOffset)
        {
            // Serves as a debug helper to make sure valid IL is not considered invalid.
            //
            // The .NET Native compiler used to warn if it detected invalid IL during treeshaking,
            // but the warnings were often triggered in autogenerated dead code of a major game engine
            // and resulted in support calls. No point in warning. If the code gets exercised at runtime,
            // an InvalidProgramException will likely be raised.
            Debug.Fail("Invalid IL or a bug in the scanner");
        }

        /// <summary>
        /// Returns the value of <paramref name="parameter"/> carrying its flow annotation.
        /// </summary>
        protected override ValueWithDynamicallyAccessedMembers GetMethodParameterValue(ParameterProxy parameter)
            => GetMethodParameterValue(parameter, _annotations.GetParameterAnnotation(parameter));

        MethodParameterValue GetMethodParameterValue(ParameterProxy parameter, DynamicallyAccessedMemberTypes dynamicallyAccessedMemberTypes)
            => _annotations.GetMethodParameterValue(parameter, dynamicallyAccessedMemberTypes);

        /// <summary>
        /// Records generic argument validation for static field accesses, then returns the
        /// annotated value of <paramref name="field"/>.
        /// </summary>
        protected override MultiValue GetFieldValue(FieldReference field)
        {
            ProcessGenericArgumentDataFlow(field);
            return _annotations.GetFieldValue(field);
        }

        /// <summary>
        /// Returns the annotated return value of the method being scanned (never a newobj result).
        /// </summary>
        protected override MethodReturnValue GetReturnValue(MethodIL methodIL) => _annotations.GetMethodReturnValue(methodIL.Method, isNewObj: false);

        /// <summary>
        /// Records an assignment pattern when <paramref name="targetValue"/> has any
        /// DynamicallyAccessedMemberTypes; unannotated targets need no validation and are skipped.
        /// </summary>
        private void HandleStoreValueWithDynamicallyAccessedMembers(ValueWithDynamicallyAccessedMembers targetValue, Instruction operation, MultiValue sourceValue, int? parameterIndex)
        {
            if (targetValue.DynamicallyAccessedMemberTypes != 0)
            {
                _origin = _origin.WithInstructionOffset(operation.Offset);
                TrimAnalysisPatterns.Add(new TrimAnalysisAssignmentPattern(sourceValue, targetValue, _origin, parameterIndex));
            }
        }

        protected override void HandleStoreField(MethodIL methodIL, FieldValue field, Instruction operation, MultiValue valueToStore, int? parameterIndex)
            => HandleStoreValueWithDynamicallyAccessedMembers(field, operation, valueToStore, parameterIndex);

        protected override void HandleStoreParameter(MethodIL methodIL, MethodParameterValue parameter, Instruction operation, MultiValue valueToStore, int? parameterIndex)
            => HandleStoreValueWithDynamicallyAccessedMembers(parameter, operation, valueToStore, parameterIndex);

        protected override void HandleReturnValue(MethodIL methodIL, MethodReturnValue returnValue, Instruction operation, MultiValue valueToStore)
            => HandleStoreValueWithDynamicallyAccessedMembers(returnValue, operation, valueToStore, null);

        protected override void HandleTypeTokenAccess(MethodIL methodIL, int offset, TypeReference accessedType)
        {
            // Note that ldtoken alone is technically a reflection access to the type
            // it doesn't lead to full reflection marking of the type
            // since we implement full dataflow for type values and accesses to them.
            _origin = _origin.WithInstructionOffset(offset);

            // Only check for generic instantiations.
            ProcessGenericArgumentDataFlow(accessedType);
        }

        protected override void HandleMethodTokenAccess(MethodIL methodIL, int offset, MethodReference accessedMethod)
        {
            _origin = _origin.WithInstructionOffset(offset);

            ProcessGenericArgumentDataFlow(accessedMethod);
        }

        protected override void HandleFieldTokenAccess(MethodIL methodIL, int offset, FieldReference accessedField)
        {
            _origin = _origin.WithInstructionOffset(offset);

            ProcessGenericArgumentDataFlow(accessedField);
        }

        /// <summary>
        /// Records a method call pattern (plus any generic instantiation pattern) for the call at
        /// <paramref name="operation"/> and computes its return value with diagnostics disabled.
        /// Returns <see cref="UnknownValue.Instance"/> when MarkStep already processed the call as a
        /// reflection dependency or when the callee cannot be resolved.
        /// </summary>
        public override MultiValue HandleCall(MethodIL callingMethodIL, MethodReference calledMethod, Instruction operation, ValueNodeList methodParams)
        {
            var reflectionProcessed = _markStep.ProcessReflectionDependency(callingMethodIL.Body, operation);
            if (reflectionProcessed)
            {
                return UnknownValue.Instance;
            }

            Debug.Assert(callingMethodIL.Method == _origin.Provider);
            var calledMethodDefinition = _context.TryResolve(calledMethod);
            if (calledMethodDefinition == null)
            {
                return UnknownValue.Instance;
            }

            _origin = _origin.WithInstructionOffset(operation.Offset);

            // For methods with an implicit 'this', the first stack value is the instance;
            // static calls have no instance, which is represented by the lattice Top value.
            MultiValue instanceValue;
            ImmutableArray<MultiValue> arguments;
            if (calledMethodDefinition.HasImplicitThis())
            {
                instanceValue = methodParams[0];
                arguments = methodParams.Skip(1).ToImmutableArray();
            }
            else
            {
                instanceValue = MultiValueLattice.Top;
                arguments = methodParams.ToImmutableArray();
            }

            TrimAnalysisPatterns.Add(new TrimAnalysisMethodCallPattern(
                operation,
                calledMethod,
                instanceValue,
                arguments,
                _origin
            ));

            ProcessGenericArgumentDataFlow(calledMethod);

            // Diagnostics are disabled here; the recorded call pattern is processed again by
            // InterproceduralScan with an enabled marker.
            var diagnosticContext = new DiagnosticContext(_origin, diagnosticsEnabled: false, _context);
            return HandleCall(
                operation,
                calledMethod,
                instanceValue,
                arguments,
                diagnosticContext,
                _reflectionMarker,
                _context,
                _markStep);
        }

        /// <summary>
        /// Evaluates a call to <paramref name="calledMethod"/> through <c>HandleCallAction</c> and returns
        /// the resulting return value.
        /// </summary>
        /// <exception cref="NotImplementedException">The callee's intrinsic id is not handled.</exception>
        public static MultiValue HandleCall(
            Instruction operation,
            MethodReference calledMethod,
            MultiValue instanceValue,
            ImmutableArray<MultiValue> argumentValues,
            DiagnosticContext diagnosticContext,
            ReflectionMarker reflectionMarker,
            LinkContext context,
            MarkStep markStep)
        {
            var origin = diagnosticContext.Origin;
            if (!MethodProxy.TryCreate(calledMethod, context, out MethodProxy? calledMethodProxy))
            {
                Debug.Fail("Should only be called for resolvable methods");
                return UnknownValue.Instance;
            }
            var calledMethodDefinition = calledMethodProxy.Value.Definition;
            var callingMethodDefinition = origin.Provider as MethodDefinition;
            Debug.Assert(callingMethodDefinition != null);

            bool requiresDataFlowAnalysis = context.Annotations.FlowAnnotations.RequiresDataFlowAnalysis(calledMethodDefinition);
            bool isNewObj = operation.OpCode.Code == Code.Newobj;
            var annotatedMethodReturnValue = context.Annotations.FlowAnnotations.GetMethodReturnValue(calledMethodProxy.Value, isNewObj);
            // An annotated return value implies the callee requires data flow analysis.
            Debug.Assert(requiresDataFlowAnalysis || annotatedMethodReturnValue.DynamicallyAccessedMemberTypes == DynamicallyAccessedMemberTypes.None);

            var handleCallAction = new HandleCallAction(context, operation, markStep, reflectionMarker, diagnosticContext, callingMethodDefinition);
            var intrinsicId = Intrinsics.GetIntrinsicIdForMethod(calledMethodProxy.Value);
            if (!handleCallAction.Invoke(calledMethodProxy.Value, instanceValue, argumentValues, intrinsicId, out MultiValue methodReturnValue))
                throw new NotImplementedException($"Unhandled intrinsic: {intrinsicId}");
            return methodReturnValue;
        }

        /// <summary>
        /// Best-effort check whether marshalling a parameter or return value of type
        /// <paramref name="parameterType"/> (with the marshal info from <paramref name="marshalInfoProvider"/>)
        /// involves COM interop.
        /// </summary>
        static bool IsComInterop(IMarshalInfoProvider marshalInfoProvider, TypeReference parameterType, LinkContext context)
        {
            // This is best effort. One can likely find ways how to get COM without triggering these alarms.
            // AsAny marshalling of a struct with an object-typed field would be one, for example.

            // This logic roughly corresponds to MarshalInfo::MarshalInfo in CoreCLR,
            // not trying to handle invalid cases and distinctions that are not interesting wrt
            // "is this COM?" question.

            NativeType nativeType = NativeType.None;
            if (marshalInfoProvider.HasMarshalInfo)
            {
                nativeType = marshalInfoProvider.MarshalInfo.NativeType;
            }

            if (nativeType == NativeType.IUnknown || nativeType == NativeType.IDispatch || nativeType == NativeType.IntF)
            {
                // This is COM by definition
                return true;
            }

            if (nativeType == NativeType.None)
            {
                if (parameterType.IsPointer)
                {
                    // Pointer types are passed without marshalling
                    return false;
                }

                // Resolve will look at the element type
                var parameterTypeDef = context.TryResolve(parameterType);

                if (parameterTypeDef != null)
                {
                    if (parameterTypeDef.IsTypeOf(WellKnownType.System_Array))
                    {
                        // System.Array marshals as IUnknown by default
                        return true;
                    }
                    else if (parameterTypeDef.IsTypeOf(WellKnownType.System_String) ||
                        parameterTypeDef.IsTypeOf("System.Text", "StringBuilder"))
                    {
                        // String and StringBuilder are special cased by interop
                        return false;
                    }

                    if (parameterTypeDef.IsValueType)
                    {
                        // Value types don't marshal as COM
                        return false;
                    }
                    else if (parameterTypeDef.IsInterface)
                    {
                        // Interface types marshal as COM by default
                        return true;
                    }
                    else if (parameterTypeDef.IsMulticastDelegate())
                    {
                        // Delegates are special cased by interop
                        return false;
                    }
                    else if (parameterTypeDef.IsSubclassOf("System.Runtime.InteropServices", "CriticalHandle", context))
                    {
                        // Subclasses of CriticalHandle are special cased by interop
                        return false;
                    }
                    else if (parameterTypeDef.IsSubclassOf("System.Runtime.InteropServices", "SafeHandle", context))
                    {
                        // Subclasses of SafeHandle are special cased by interop
                        return false;
                    }
                    else if (!parameterTypeDef.IsSequentialLayout && !parameterTypeDef.IsExplicitLayout)
                    {
                        // Rest of classes that don't have layout marshal as COM
                        return true;
                    }
                }
            }

            return false;
        }

        /// <summary>
        /// Records a generic instantiation access pattern for <paramref name="method"/> when its generic
        /// arguments require data flow validation.
        /// </summary>
        private void ProcessGenericArgumentDataFlow(MethodReference method)
        {
            // We mostly need to validate static methods and generic methods
            // Instance non-generic methods on reference types don't need validation
            // because the creation of the instance is the place where the validation will happen.
            if (_context.TryResolve(method) is not MethodDefinition methodDefinition)
                return;

            if (!methodDefinition.IsStatic && !method.IsGenericInstance && !methodDefinition.IsConstructor && !methodDefinition.DeclaringType.IsValueType)
                return;

            if (GenericArgumentDataFlow.RequiresGenericArgumentDataFlow(_annotations, method))
            {
                TrimAnalysisPatterns.Add(new TrimAnalysisGenericInstantiationAccessPattern(method, _origin));
            }
        }

        /// <summary>
        /// Records a generic instantiation access pattern for a static <paramref name="field"/> access
        /// when the field's generic arguments require data flow validation.
        /// </summary>
        private void ProcessGenericArgumentDataFlow(FieldReference field)
        {
            // We only need to validate static field accesses, instance field accesses don't need generic parameter validation
            // because the creation of the instance would do that instead.
            if (_context.TryResolve(field) is not FieldDefinition fieldDefinition)
                return;

            if (!fieldDefinition.IsStatic)
                return;

            if (GenericArgumentDataFlow.RequiresGenericArgumentDataFlow(_annotations, field))
            {
                TrimAnalysisPatterns.Add(new TrimAnalysisGenericInstantiationAccessPattern(field, _origin));
            }
        }

        /// <summary>
        /// Records a generic instantiation access pattern for <paramref name="type"/> when its generic
        /// arguments require data flow validation.
        /// </summary>
        private void ProcessGenericArgumentDataFlow(TypeReference type)
        {
            // Use the same predicate as the method/field overloads and as
            // RequiresReflectionMethodBodyScannerForAccess(LinkContext, TypeReference). That way every type
            // access which causes this scanner to run is also validated here. A top-level-only check
            // (IsGenericInstance && HasGenericParameterAnnotation) can disagree with that predicate.
            // NOTE(review): RequiresGenericArgumentDataFlow is assumed to also cover annotated parameters
            // in nested generic arguments — confirm against its implementation.
            if (GenericArgumentDataFlow.RequiresGenericArgumentDataFlow(_annotations, type))
            {
                TrimAnalysisPatterns.Add(new TrimAnalysisGenericInstantiationAccessPattern(type, _origin));
            }
        }

        /// <summary>
        /// Returns whether <paramref name="methodDefinition"/> is a P/Invoke whose return value or any
        /// parameter involves COM interop. <paramref name="comDangerousMethod"/> receives the same result.
        /// Non-P/Invoke methods always return <c>false</c>.
        /// </summary>
        internal static bool IsPInvokeDangerous(MethodDefinition methodDefinition, LinkContext context, out bool comDangerousMethod)
        {
            // The method in ILLink only detects one condition - COM Dangerous, but it's structured like this
            // so that the code looks very similar to AOT which has more than one condition.

            if (!methodDefinition.IsPInvokeImpl)
            {
                comDangerousMethod = false;
                return false;
            }

            comDangerousMethod = IsComInterop(methodDefinition.MethodReturnType, methodDefinition.ReturnType, context);
#pragma warning disable RS0030 // MethodDefinition.Parameters is banned. Here we iterate through the parameters and don't need to worry about the 'this' parameter.
            foreach (ParameterDefinition pd in methodDefinition.Parameters)
            {
                comDangerousMethod |= IsComInterop(pd, pd.ParameterType, context);
            }
#pragma warning restore RS0030

            return comDangerousMethod;
        }
    }
}
